package ci.web.codec;

import ci.web.router.CiWebSocketHandler;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.AttributeKey;

/**
 * HTTP handler that upgrades requests carrying a {@code Sec-WebSocket-Version}
 * header to WebSocket connections and forwards the resulting frames to a
 * {@link CiWebSocketHandler}.
 * <p>
 * WebSocket frames are dispatched to the callbacks of the configured
 * {@link CiWebSocketHandler}. Ping frames are answered with a Pong. Every other
 * message, including plain HTTP requests without the version header, is
 * delegated to {@link CiHttpHandler#channelRead0}.
 * <p>
 * The handshaker of an upgraded connection is stored as a channel attribute
 * rather than in a field. This class's own fields are set once in the
 * constructor, which is what allows it to be marked {@link ChannelHandler.Sharable}.
 *
 * @author zhh
 */
@ChannelHandler.Sharable
public abstract class CiHttpWebSocketHandler extends CiHttpHandler {

    /**
     * Channel attribute holding the handshaker of an upgraded connection.
     * It is set after a successful handshake and removed when the connection
     * is closed, either by a close frame or by the channel going inactive.
     */
    private static final AttributeKey<WebSocketServerHandshaker> WEBSOCKET_KEY =
            AttributeKey.valueOf("WEBSOCKET_KEY");

    /** Path appended to the Host header to build the WebSocket URL given to the handshaker factory. */
    private final String webSocketPath;
    /** Receiver of WebSocket open / message / close callbacks. */
    private final CiWebSocketHandler wsHandler;

    /**
     * Creates a handler whose handshake URL uses the path {@code /websocket}.
     *
     * @param wsHandler receiver of WebSocket lifecycle and message callbacks
     * @param limit     passed unchanged to the {@link CiHttpHandler} constructor
     */
    public CiHttpWebSocketHandler(CiWebSocketHandler wsHandler, int limit) {
        this(wsHandler, "/websocket", limit);
    }

    /**
     * Creates a handler whose handshake URL uses the given path.
     *
     * @param wsHandler receiver of WebSocket lifecycle and message callbacks
     * @param wsPath    path appended to {@code ws://host} / {@code wss://host} when building the handshake URL
     * @param limit     passed unchanged to the {@link CiHttpHandler} constructor
     */
    public CiHttpWebSocketHandler(CiWebSocketHandler wsHandler, String wsPath, int limit) {
        super(limit);
        this.wsHandler = wsHandler;
        this.webSocketPath = wsPath;
    }

    /**
     * Routes an inbound message.
     * <ul>
     *   <li>WebSocket frames go to {@link #onData}.</li>
     *   <li>HTTP requests carrying {@code Sec-WebSocket-Version} start a handshake.</li>
     *   <li>Everything else goes to the superclass.</li>
     * </ul>
     */
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, Object msg)
            throws Exception {
        if (msg instanceof WebSocketFrame) {
            onData(ctx.channel(), (WebSocketFrame) msg);
            return;
        }
        if (msg instanceof FullHttpRequest) {
            FullHttpRequest req = (FullHttpRequest) msg;
            // NOTE(review): the request path is not compared with webSocketPath, so an
            // upgrade is attempted on any URI. Confirm this is intended.
            if (req.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_VERSION)) {
                handshake(ctx, req);
                return;
            }
        }
        super.channelRead0(ctx, msg);
    }

    /**
     * Performs the WebSocket opening handshake for {@code req}.
     * <p>
     * An unsupported protocol version gets the factory's "unsupported version"
     * response, and the channel is then closed. On success, the handshaker is
     * stored under {@link #WEBSOCKET_KEY} and {@link CiWebSocketHandler#onOpen}
     * is invoked. On failure, the cause is fired through the pipeline's
     * exception path.
     */
    private void handshake(ChannelHandlerContext ctx, FullHttpRequest req) {
        // The scheme follows whether TLS is present in the pipeline.
        String scheme = ctx.pipeline().get(SslHandler.class) == null ? "ws://" : "wss://";
        String url = scheme + req.headers().getAsString(HttpHeaderNames.HOST) + webSocketPath;
        WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(url, null, false);
        final WebSocketServerHandshaker hs = wsFactory.newHandshaker(req);
        if (hs == null) {
            WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel())
                    .addListener(ChannelFutureListener.CLOSE);
            return;
        }
        hs.handshake(ctx.channel(), req).addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture f) throws Exception {
                if (f.isSuccess()) {
                    f.channel().attr(WEBSOCKET_KEY).set(hs);
                    onOpen(f.channel());
                } else {
                    f.channel().pipeline().fireExceptionCaught(f.cause());
                }
            }
        });
    }

    /**
     * Notifies the application that an upgraded connection was lost.
     * <p>
     * {@code onClose(channel, true)} is only called if the handshaker attribute
     * is still present, i.e. no close frame was processed first. That path
     * removes the attribute and reports {@code lost == false} itself.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        super.channelInactive(ctx);
        if (ctx.channel().attr(WEBSOCKET_KEY).getAndRemove() != null) {
            onClose(ctx.channel(), true);
        }
    }

    private void onOpen(Channel channel) {
        wsHandler.onOpen(channel);
    }

    private void onClose(Channel channel, boolean lost) {
        wsHandler.onClose(channel, lost);
    }

    /**
     * Dispatches a single WebSocket frame.
     * <ul>
     *   <li>Text frames are passed to {@code onText}.</li>
     *   <li>Binary frames are copied into a fresh array and passed to {@code onBytes}.</li>
     *   <li>Ping frames are answered with a Pong carrying the same payload.</li>
     *   <li>A close frame triggers the closing handshake, followed by {@code onClose(channel, false)}.</li>
     * </ul>
     * Other frame types (pong, continuation) are ignored.
     * NOTE(review): fragmented messages are only delivered whole if a frame
     * aggregator sits earlier in the pipeline. Confirm the pipeline setup.
     */
    private void onData(Channel channel, WebSocketFrame frame) {
        if (frame instanceof TextWebSocketFrame) {
            wsHandler.onText(channel, ((TextWebSocketFrame) frame).text());
        } else if (frame instanceof BinaryWebSocketFrame) {
            BinaryWebSocketFrame b = (BinaryWebSocketFrame) frame;
            byte[] bytes = new byte[b.content().readableBytes()];
            b.content().readBytes(bytes);
            wsHandler.onBytes(channel, bytes);
        } else if (frame instanceof PingWebSocketFrame) {
            // RFC 6455 section 5.5.2: a Ping must be answered with a Pong echoing its payload.
            // Without this branch, pings were silently dropped and keep-alive clients timed out.
            // retain() mirrors the close branch below. It assumes the base handler releases
            // the inbound frame after channelRead0 returns (TODO confirm in CiHttpHandler).
            channel.writeAndFlush(new PongWebSocketFrame(frame.content().retain()));
        } else if (frame instanceof CloseWebSocketFrame) {
            // Removing the attribute first prevents channelInactive from reporting a second close.
            WebSocketServerHandshaker hs = channel.attr(WEBSOCKET_KEY).getAndRemove();
            if (hs != null) {
                hs.close(channel, (CloseWebSocketFrame) frame.retain()).addListener(new ChannelFutureListener() {
                    @Override
                    public void operationComplete(ChannelFuture future) throws Exception {
                        onClose(future.channel(), false);
                    }
                });
            }
        }
    }
}
